In [1]:
# Define experiment parameters
# Survey wave to analyse; used in the data/model/result file names below.
year = "201516"
# Binary outcome column to predict.
target_col = "white_collar"  # 'white_collar', 'blue_collar', 'has_occ'
# Survey sampling-weight column, used as per-sample weight in training/sampling.
sample_weight_col = 'women_weight'
In [2]:
# Define resource utilization parameters
random_state = 42  # global seed for numpy / train-test split / model reproducibility
n_jobs_clf = 16    # threads used by the LightGBM classifier
n_jobs_cv = 4      # parallel workers for cross-validated search
cv_folds = 5       # number of StratifiedKFold splits
In [3]:
import numpy as np
np.random.seed(random_state)

import pandas as pd
pd.set_option('display.max_columns', 500)

import matplotlib.pylab as pl

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

from sklearn.utils.class_weight import compute_class_weight

import lightgbm
from lightgbm import LGBMClassifier

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.model_selection import StratifiedKFold

import shap

import pickle
from joblib import dump, load

Prepare Dataset

In [4]:
# Load dataset
# NOTE(review): appears to be a preprocessed DHS/NFHS women's survey file
# (caste/wealth-index/sampling-weight columns) — confirm provenance.
dataset = pd.read_csv(f"data/women_work_data_{year}.csv")
print("Loaded dataset: ", dataset.shape)
dataset.head()
Loaded dataset:  (111398, 26)
Out[4]:
Unnamed: 0 case_id_str line_no country_code cluster_no hh_no state wealth_index hh_religion caste women_weight women_anemic obese_female urban freq_tv age occupation years_edu hh_members no_children_below5 white_collar blue_collar no_occ has_occ year total_children
0 8 1000117.0 2 IA6 10001 17 andaman and nicobar islands middle hindu NaN 0.191636 1.0 0.0 1.0 3.0 23.0 0.0 10.0 2.0 0.0 0.0 0.0 1.0 0.0 2015.0 0.0
1 9 1000120.0 1 IA6 10001 20 andaman and nicobar islands richer hindu none of above 0.191636 0.0 0.0 1.0 3.0 35.0 8.0 8.0 3.0 0.0 0.0 1.0 0.0 1.0 2015.0 2.0
2 11 1000129.0 2 IA6 10001 29 andaman and nicobar islands richest muslim other backward class 0.191636 1.0 0.0 1.0 3.0 46.0 0.0 12.0 3.0 0.0 0.0 0.0 1.0 0.0 2015.0 2.0
3 12 1000129.0 3 IA6 10001 29 andaman and nicobar islands richest muslim other backward class 0.191636 1.0 0.0 1.0 3.0 17.0 0.0 11.0 3.0 0.0 0.0 0.0 1.0 0.0 2015.0 0.0
4 13 1000130.0 2 IA6 10001 30 andaman and nicobar islands richer christian scheduled caste 0.191636 1.0 1.0 1.0 3.0 30.0 0.0 8.0 5.0 0.0 0.0 0.0 1.0 0.0 2015.0 3.0
In [5]:
# See distribution of target values
# (classes are heavily imbalanced — motivates the class/sample weighting used later)
print("Target column distribution:\n", dataset[target_col].value_counts(dropna=False))
Target column distribution:
 0.0    105988
1.0      5410
Name: white_collar, dtype: int64
In [6]:
# Drop samples where the target OR the sampling weight is missing
# (rows lacking either cannot be used for weighted training/evaluation;
# the original comment mentioned only the target). Reassignment instead of
# inplace=True keeps the step idempotent and chainable.
dataset = dataset.dropna(axis=0, subset=[target_col, sample_weight_col])
print("Drop missing targets: ", dataset.shape)
Drop missing targets:  (111398, 26)
In [7]:
# Restrict the sample to adult respondents (21 years or older)
dataset = dataset.query('age >= 21')
print("Drop under-21 samples: ", dataset.shape)
Drop under-21 samples:  (86825, 26)
In [8]:
# See new distribution of target values after the age filter
print("Target column distribution:\n", dataset[target_col].value_counts(dropna=False))
Target column distribution:
 0.0    81831
1.0     4994
Name: white_collar, dtype: int64
In [9]:
# Post-processing

# Group SC/ST castes together.
# FIX: the original used chained indexing (dataset['caste'][mask] = ...),
# which triggers SettingWithCopyWarning and may silently fail to write
# back to the frame. .loc assigns on the original frame directly.
dataset.loc[dataset['caste'].isin(['scheduled caste', 'scheduled tribe']), 'caste'] = 'sc/st'
if year == "200506":
    # In the 2005-06 wave, code '9' denotes "don't know"
    dataset.loc[dataset['caste'] == '9', 'caste'] = "don't know"

# Fix naming for General caste
dataset.loc[dataset['caste'] == 'none of above', 'caste'] = 'general'

if year == "201516":
    # Convert wealth index from str to int values (ordinal: poorest=0 .. richest=4).
    # The list comprehension (rather than .map) deliberately raises KeyError
    # on any unexpected label instead of silently producing NaN.
    wi_dict = {'poorest': 0, 'poorer': 1, 'middle': 2, 'richer': 3, 'richest': 4}
    dataset['wealth_index'] = [wi_dict[wi] for wi in dataset['wealth_index']]
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:5: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  """
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:10: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  # Remove the CWD from sys.path while we load stuff.
In [10]:
# Define feature columns
# Categorical features get integer codes later ('caste' is one-hot encoded separately)
x_cols_categorical = ['state', 'hh_religion', 'caste']
# 0/1 indicator features
x_cols_binary = ['urban', 'women_anemic', 'obese_female']
# Ordinal / count-valued features
x_cols_numeric = ['age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
x_cols = x_cols_categorical + x_cols_binary + x_cols_numeric
print("Feature columns:\n", x_cols)
Feature columns:
 ['state', 'hh_religion', 'caste', 'urban', 'women_anemic', 'obese_female', 'age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
In [11]:
# Drop samples with missing values in any feature column.
# Reassignment instead of inplace=True keeps the step idempotent on re-run.
dataset = dataset.dropna(axis=0, subset=x_cols)
print("Drop missing feature value rows: ", dataset.shape)
Drop missing feature value rows:  (81816, 26)
In [12]:
# Separate target column
targets = dataset[target_col]
# Separate sampling weight column
sample_weights = dataset[sample_weight_col]
# Keep only the feature columns, preserving the original column order.
# FIX: `columns=` already implies axis=1, so the redundant `axis=1`
# argument is removed; reassignment replaces inplace=True.
dataset = dataset.drop(columns=[col for col in dataset.columns if col not in x_cols])
print("Drop extra columns: ", dataset.shape)
Drop extra columns:  (81816, 13)
In [13]:
# Obtain one-hot encodings for the caste column
dataset = pd.get_dummies(dataset, columns=['caste'])
# 'caste' no longer exists as a single column, so drop it from the
# categorical list before the integer-encoding step below.
x_cols_categorical.remove('caste')  # Remove 'caste' from categorical variables list
print("Caste to one-hot: ", dataset.shape)
Caste to one-hot:  (81816, 16)
In [15]:
# Copy of the frame with human-readable column names, used only for SHAP display.
# NOTE(review): this rename is position-based — it assumes the exact column
# order produced by get_dummies above; re-verify if the feature set changes.
dataset_display = dataset.copy()
dataset_display.columns = ['State', 'Wealth Index', 'Hh. Religion', 'Anemic', 'Obese',
                           'Residence Type', 'Freq. of TV', 'Age', 'Yrs. of Education', 'Hh. Members',
                           'Children Below 5', 'Total Children', 'Unknown Caste',
                           'General Caste', 'OBC Caste', 'Sc/St Caste']
print("Create copy for visualization: ", dataset_display.shape)
dataset_display.head()
Create copy for visualization:  (81816, 16)
Out[15]:
State Wealth Index Hh. Religion Anemic Obese Residence Type Freq. of TV Age Yrs. of Education Hh. Members Children Below 5 Total Children Unknown Caste General Caste OBC Caste Sc/St Caste
1 andaman and nicobar islands 3 hindu 0.0 0.0 1.0 3.0 35.0 8.0 3.0 0.0 2.0 0 1 0 0
2 andaman and nicobar islands 4 muslim 1.0 0.0 1.0 3.0 46.0 12.0 3.0 0.0 2.0 0 0 1 0
4 andaman and nicobar islands 3 christian 1.0 1.0 1.0 3.0 30.0 8.0 5.0 0.0 3.0 0 0 0 1
5 andaman and nicobar islands 3 christian 1.0 0.0 1.0 3.0 21.0 12.0 5.0 0.0 0.0 0 0 0 1
7 andaman and nicobar islands 4 hindu 1.0 1.0 1.0 3.0 40.0 8.0 2.0 0.0 2.0 0 1 0 0
In [16]:
# Replace each remaining categorical feature with integer codes.
# pd.factorize returns (codes, uniques); only the codes are kept.
for cat_col in x_cols_categorical:
    codes, _uniques = pd.factorize(dataset[cat_col])
    dataset[cat_col] = codes
print("Categoricals to int encodings: ", dataset.shape)
Categoricals to int encodings:  (81816, 16)
In [17]:
# Inspect the fully encoded feature matrix
dataset.head()
Out[17]:
state wealth_index hh_religion women_anemic obese_female urban freq_tv age years_edu hh_members no_children_below5 total_children caste_don't know caste_general caste_other backward class caste_sc/st
1 0 3 0 0.0 0.0 1.0 3.0 35.0 8.0 3.0 0.0 2.0 0 1 0 0
2 0 4 1 1.0 0.0 1.0 3.0 46.0 12.0 3.0 0.0 2.0 0 0 1 0
4 0 3 2 1.0 1.0 1.0 3.0 30.0 8.0 5.0 0.0 3.0 0 0 0 1
5 0 3 2 1.0 0.0 1.0 3.0 21.0 12.0 5.0 0.0 0.0 0 0 0 1
7 0 4 0 1.0 1.0 1.0 3.0 40.0 8.0 2.0 0.0 2.0 0 1 0 0
In [18]:
# Create Training, Validation and Test sets.
# Stratified split keeps the rare positive class proportionally represented
# in the 5% held-out test set; sample weights are split alongside X and Y.
X_train, X_test, Y_train, Y_test, W_train, W_test = train_test_split(dataset, targets, sample_weights, test_size=0.05, random_state=random_state, stratify=targets)
# X_train, X_val, Y_train, Y_val, W_train, W_val = train_test_split(X_train, Y_train, W_train, test_size=0.1)
print("Training set: ", X_train.shape, Y_train.shape, W_train.shape)
# print("Validation set: ", X_val.shape, Y_val.shape, W_val.shape)
print("Test set: ", X_test.shape, Y_test.shape, W_test.shape)
# NOTE(review): train_cw is printed for reference but does not appear to be
# passed to any model in this notebook — confirm whether it is still needed.
train_cw = compute_class_weight("balanced", classes=np.unique(Y_train), y=Y_train)
print("Class weights: ", train_cw)
Training set:  (77725, 16) (77725,) (77725,)
Test set:  (4091, 16) (4091,) (4091,)
Class weights:  [0.53069822 8.64379448]

Build LightGBM Classifier

In [19]:
# # Define LightGBM Classifier
# model = LGBMClassifier(boosting_type='gbdt', 
#                        feature_fraction=0.8,  
#                        learning_rate=0.01,
#                        max_bins=64,
#                        max_depth=-1,
#                        min_child_weight=0.001,
#                        min_data_in_leaf=50,
#                        min_split_gain=0.0,
#                        num_iterations=1000,
#                        num_leaves=64,
#                        reg_alpha=0,
#                        reg_lambda=1,
#                        subsample_for_bin=200000,
#                        is_unbalance=True,
#                        random_state=random_state, 
#                        n_jobs=n_jobs_clf, 
#                        silent=True, 
#                        importance_type='split')
In [20]:
# # Fit model on training set
# model.fit(X_train, Y_train, sample_weight=W_train.values, 
#           #categorical_feature=x_cols_categorical,
#           categorical_feature=[])
In [21]:
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
In [22]:
# # Save trained model
# dump(model, f'models/{target_col}-{year}-model.joblib')
# del model
In [23]:
# # Define hyperparameter grid
# param_grid = {
#     'num_leaves': [8, 32, 64],
#     'min_data_in_leaf': [10, 20, 50],
#     'max_depth': [-1], 
#     'learning_rate': [0.01, 0.1], 
#     'num_iterations': [1000, 3000, 5000], 
#     'subsample_for_bin': [200000],
#     'min_split_gain': [0.0], 
#     'min_child_weight': [0.001],
#     'feature_fraction': [0.8, 1.0], 
#     'reg_alpha': [0], 
#     'reg_lambda': [0, 1],
#     'max_bin': [64, 128, 255]
# }
In [24]:
# # Define LightGBM Classifier
# clf = LGBMClassifier(boosting_type='gbdt',
#                      objective='binary', 
#                      is_unbalance=True,
#                      random_state=random_state,
#                      n_jobs=n_jobs_clf, 
#                      silent=True, 
#                      importance_type='split')

# # Define K-fold cross validation splitter
# kfold = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)

# # Perform grid search
# model = GridSearchCV(clf, param_grid=param_grid, scoring='f1', n_jobs=n_jobs_cv, cv=kfold, refit=True, verbose=3)
# model.fit(X_train, Y_train, 
#           sample_weight=W_train.values, 
#           #categorical_feature=x_cols_categorical,
#           categorical_feature=[])

# print('\n All results:')
# print(model.cv_results_)
# print('\n Best estimator:')
# print(model.best_estimator_)
# print('\n Best hyperparameters:')
# print(model.best_params_)
In [25]:
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions, average='micro'))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
In [26]:
# # Save trained model
# dump(model, f'models/{target_col}-{year}-gridsearch.joblib')
# del model

Load LightGBM Classifier

In [27]:
# Load the best estimator found by the previously run grid search
# (the commented line loads the single manually configured model instead).
# model = load(f'models/{target_col}-{year}-model.joblib')
model = load(f'models/{target_col}-{year}-gridsearch.joblib').best_estimator_
In [28]:
# Sanity check: evaluate the loaded model on the held-out Test set.
# Each metric is printed on its own line, in the same order as before.
predictions = model.predict(X_test)
for report_fn in (accuracy_score, f1_score, confusion_matrix, classification_report):
    print(report_fn(Y_test, predictions))
0.8530921535076998
0.30520231213872834
[[3358  496]
 [ 105  132]]
              precision    recall  f1-score   support

         0.0       0.97      0.87      0.92      3854
         1.0       0.21      0.56      0.31       237

   micro avg       0.85      0.85      0.85      4091
   macro avg       0.59      0.71      0.61      4091
weighted avg       0.93      0.85      0.88      4091

In [29]:
# Overfitting check: the same metrics on the Train set; a large gap versus
# the Test-set numbers above would indicate overfitting.
predictions = model.predict(X_train)
for report_fn in (accuracy_score, f1_score, confusion_matrix, classification_report):
    print(report_fn(Y_train, predictions))
0.8986426503698939
0.510013683293942
[[65747  7482]
 [  396  4100]]
              precision    recall  f1-score   support

         0.0       0.99      0.90      0.94     73229
         1.0       0.35      0.91      0.51      4496

   micro avg       0.90      0.90      0.90     77725
   macro avg       0.67      0.90      0.73     77725
weighted avg       0.96      0.90      0.92     77725


Visualizations/Explainations

Note that these plots just explain how the LightGBM model works, not necessarily how reality works. Since the LightGBM model is trained from observational data, it is not necessarily a causal model, and so just because changing a factor makes the model's predicted probability go up, does not always mean it will change the actual outcome.

In [30]:
# print the JS visualization code to the notebook
# (required once per notebook for interactive shap.force_plot rendering)
shap.initjs()

What makes a measure of feature importance good or bad?

  1. Consistency: Whenever we change a model such that it relies more on a feature, then the attributed importance for that feature should not decrease.
  2. Accuracy. The sum of all the feature importances should sum up to the total importance of the model. (For example if importance is measured by the R² value then the attribution to each feature should sum to the R² of the full model)

If consistency fails to hold, then we can’t compare the attributed feature importances between any two models, because then having a higher assigned attribution doesn’t mean the model actually relies more on that feature.

If accuracy fails to hold then we don’t know how the attributions of each feature combine to represent the output of the whole model. We can’t just normalize the attributions after the method is done since this might break the consistency of the method.

Using Tree SHAP for interpreting the model

In [31]:
# Build a TreeExplainer for the trained model; SHAP values themselves are
# loaded from a cached earlier run rather than recomputed.
explainer = shap.TreeExplainer(model)
# shap_values = explainer.shap_values(dataset)
# NOTE(review): pickle.load executes arbitrary code — only load caches you
# created yourself. Also confirm the cached values were computed on the
# current `dataset` (same rows, same order), or all plots will be mislabelled.
shap_values = pickle.load(open(f'res/{target_col}-{year}-shapvals.obj', 'rb'))
In [32]:
# Visualize a single prediction
# (red features push the prediction higher, blue push it lower)
shap.force_plot(explainer.expected_value, shap_values[0,:], dataset_display.iloc[0,:])
Out[32]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

The above explanation shows features each contributing to push the model output from the base value (the average model output over the training dataset we passed) to the model output. Features pushing the prediction higher are shown in red, those pushing the prediction lower are in blue.

If we take many explanations such as the one shown above, rotate them 90 degrees, and then stack them horizontally, we can see explanations for an entire dataset (in the notebook this plot is interactive):

In [33]:
# Visualize many predictions.
# FIX: np.random.choice samples WITH replacement by default, so the original
# "random sub-sample" could contain the same row several times;
# replace=False draws 1000 distinct rows.
subsample = np.random.choice(len(dataset), 1000, replace=False)  # Take random sub-sample
shap.force_plot(explainer.expected_value, shap_values[subsample,:], dataset_display.iloc[subsample,:])
Out[33]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Summary Plots

In [34]:
# Print the global mean(|SHAP value|) importance of every feature
mean_abs_shap = np.abs(shap_values).mean(axis=0)
for feature_name, importance in zip(dataset.columns, mean_abs_shap):
    print(f"{feature_name} - {importance}")
state - 0.44078384351160227
wealth_index - 0.2298890692702214
hh_religion - 0.16432921411524695
women_anemic - 0.06989211593867536
obese_female - 0.05931428067916626
urban - 0.17881266952506208
freq_tv - 0.1358988650924295
age - 0.393430730763656
years_edu - 0.7916145903212334
hh_members - 0.29093601341278735
no_children_below5 - 0.15709428780639415
total_children - 0.29410172320955114
caste_don't know - 0.003097509005274771
caste_general - 0.05973713325799156
caste_other backward class - 0.048845476642463324
caste_sc/st - 0.05836943714216214
In [35]:
# Bar chart of global mean(|SHAP value|) per feature
shap.summary_plot(shap_values, dataset, plot_type="bar")

The above figure shows the global mean(|Tree SHAP|) method applied to our model.

The x-axis is essentially the average magnitude change in model output when a feature is “hidden” from the model (for this model the output has log-odds units). “Hidden” means integrating the variable out of the model. Since the impact of hiding a feature changes depending on what other features are also hidden, Shapley values are used to enforce consistency and accuracy.

However, since we now have individualized explanations for every person in our dataset, to get an overview of which features are most important for a model we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples, and uses SHAP values to show the distribution of the impacts each feature has on the model output. The color represents the feature value (red high, blue low):

In [36]:
# Beeswarm plot: per-sample SHAP values, coloured by feature value
shap.summary_plot(shap_values, dataset_display)
  • Every person has one dot on each row.
  • The x position of the dot is the impact of that feature on the model’s prediction for the person.
  • The color of the dot represents the value of that feature for the person. Categorical variables are colored grey.
  • Dots that don’t fit on the row pile up to show density (since our dataset is large).
  • Since the LightGBM model has a logistic loss the x-axis has units of log-odds (Tree SHAP explains the change in the margin output of the model).

How to use this: We can make analyses similar to the blog post for interpreting our models.


SHAP Dependence Plots

Next, to understand how a single feature effects the output of the model we can plot the SHAP value of that feature vs. the value of the feature for all the examples in a dataset. SHAP dependence plots show the effect of a single feature across the whole dataset. They plot a feature's value vs. the SHAP value of that feature across many samples.

SHAP dependence plots are similar to partial dependence plots, but account for the interaction effects present in the features, and are only defined in regions of the input space supported by data. The vertical dispersion of SHAP values at a single feature value is driven by interaction effects, and another feature is chosen for coloring to highlight possible interactions. One of the benefits of SHAP dependence plots over traditional partial dependence plots is this ability to distinguish between models with and without interaction terms. In other words, SHAP dependence plots give an idea of the magnitude of the interaction terms through the vertical variance of the scatter plot at a given feature value.

Good example of using Dependency Plots: https://slundberg.github.io/shap/notebooks/League%20of%20Legends%20Win%20Prediction%20with%20XGBoost.html

Plots for 'age'

In [37]:
# Dependence plots centred on 'age': 'age' against several candidate
# interaction features, plus 'age' as the colouring feature for
# 'hh_religion' and 'state'.
age_partners = ['age', 'urban', 'caste_sc/st', 'caste_general',
                'wealth_index', 'years_edu', 'no_children_below5', 'total_children']
pairs = [('age', partner) for partner in age_partners]
pairs += [('hh_religion', 'age'), ('state', 'age')]

# Dependence plots between pairs
for feature, colour_by in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {colour_by}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=colour_by)
Feature: age, Interaction Feature: age
Feature: age, Interaction Feature: urban
Feature: age, Interaction Feature: caste_sc/st
Feature: age, Interaction Feature: caste_general
Feature: age, Interaction Feature: wealth_index
Feature: age, Interaction Feature: years_edu
Feature: age, Interaction Feature: no_children_below5
Feature: age, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: age
Feature: state, Interaction Feature: age

Plots for 'wealth_index'

In [38]:
# Dependence plots centred on 'wealth_index', mirroring the 'age' cell above.
wealth_partners = ['wealth_index', 'age', 'urban', 'caste_sc/st', 'caste_general',
                   'years_edu', 'no_children_below5', 'total_children']
pairs = [('wealth_index', partner) for partner in wealth_partners]
pairs += [('hh_religion', 'wealth_index'), ('state', 'wealth_index')]

# Dependence plots between pairs
for feature, colour_by in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {colour_by}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=colour_by)
Feature: wealth_index, Interaction Feature: wealth_index
Feature: wealth_index, Interaction Feature: age
Feature: wealth_index, Interaction Feature: urban
Feature: wealth_index, Interaction Feature: caste_sc/st
Feature: wealth_index, Interaction Feature: caste_general
Feature: wealth_index, Interaction Feature: years_edu
Feature: wealth_index, Interaction Feature: no_children_below5
Feature: wealth_index, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: wealth_index
Feature: state, Interaction Feature: wealth_index

Plots for 'years_edu'

In [39]:
# Dependence plots centred on 'years_edu', mirroring the 'age' cell above.
edu_partners = ['years_edu', 'age', 'urban', 'caste_sc/st', 'caste_general',
                'wealth_index', 'no_children_below5', 'total_children']
pairs = [('years_edu', partner) for partner in edu_partners]
pairs += [('hh_religion', 'years_edu'), ('state', 'years_edu')]

# Dependence plots between pairs
for feature, colour_by in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {colour_by}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=colour_by)
Feature: years_edu, Interaction Feature: years_edu
Feature: years_edu, Interaction Feature: age
Feature: years_edu, Interaction Feature: urban
Feature: years_edu, Interaction Feature: caste_sc/st
Feature: years_edu, Interaction Feature: caste_general
Feature: years_edu, Interaction Feature: wealth_index
Feature: years_edu, Interaction Feature: no_children_below5
Feature: years_edu, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: years_edu
Feature: state, Interaction Feature: years_edu

Plots for 'caste_sc/st'

In [40]:
# Dependence plots centred on the 'caste_sc/st' indicator.
scst_partners = ['caste_sc/st', 'age', 'urban', 'years_edu',
                 'wealth_index', 'no_children_below5', 'total_children']
pairs = [('caste_sc/st', partner) for partner in scst_partners]
pairs += [('hh_religion', 'caste_sc/st'), ('state', 'caste_sc/st')]

# Dependence plots between pairs
for feature, colour_by in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {colour_by}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=colour_by)
Feature: caste_sc/st, Interaction Feature: caste_sc/st
Feature: caste_sc/st, Interaction Feature: age
Feature: caste_sc/st, Interaction Feature: urban
Feature: caste_sc/st, Interaction Feature: years_edu
Feature: caste_sc/st, Interaction Feature: wealth_index
Feature: caste_sc/st, Interaction Feature: no_children_below5
Feature: caste_sc/st, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: caste_sc/st
Feature: state, Interaction Feature: caste_sc/st

Plots for 'caste_general'

In [41]:
# Dependence plots centred on the 'caste_general' indicator.
general_partners = ['caste_general', 'age', 'urban', 'years_edu',
                    'wealth_index', 'no_children_below5', 'total_children']
pairs = [('caste_general', partner) for partner in general_partners]
pairs += [('hh_religion', 'caste_general'), ('state', 'caste_general')]

# Dependence plots between pairs
for feature, colour_by in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {colour_by}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=colour_by)
Feature: caste_general, Interaction Feature: caste_general
Feature: caste_general, Interaction Feature: age
Feature: caste_general, Interaction Feature: urban
Feature: caste_general, Interaction Feature: years_edu
Feature: caste_general, Interaction Feature: wealth_index
Feature: caste_general, Interaction Feature: no_children_below5
Feature: caste_general, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: caste_general
Feature: state, Interaction Feature: caste_general

Visualizing Bar/Summary plots split by age bins

In [42]:
bins = [(21,25), (26,30), (31,35), (36,40), (41,45), (46,50)]

for low, high in bins:
    # Sample dataset by age range
    dataset_sample = dataset[(dataset.age > low) & (dataset.age <= high)]
    dataset_display_sample = dataset_display[(dataset.age > low) & (dataset.age <= high)]
    targets_sample = targets[(dataset.age > low) & (dataset.age <= high)]
    shap_values_sample = shap_values[(dataset.age > low) & (dataset.age <= high)]
    
    print("\nAge Range: {} - {} years".format(low, high))
    print("Sample size: {}\n".format(len(dataset_sample)))
    
    for col, sv in zip(dataset_sample.columns, np.abs(shap_values_sample).mean(0)):
        print(f"{col} - {sv}")
    
    # Summary plots
    shap.summary_plot(shap_values_sample, dataset_sample, plot_type="bar")
    shap.summary_plot(shap_values_sample, dataset_display_sample)
Age Range: 21 - 25 years
Sample size: 14034

state - 0.4311934364755243
wealth_index - 0.2424483837932801
hh_religion - 0.13868080571189895
women_anemic - 0.08050209773835065
obese_female - 0.04936720758204492
urban - 0.2148206572802434
freq_tv - 0.16164200534101056
age - 0.6000123082106207
years_edu - 0.8582757309801283
hh_members - 0.3424109257309364
no_children_below5 - 0.22766418874682465
total_children - 0.24337616166895387
caste_don't know - 0.003703992506502956
caste_general - 0.04803600926800763
caste_other backward class - 0.05438717473508545
caste_sc/st - 0.0698214803978477
Age Range: 26 - 30 years
Sample size: 13706

state - 0.4201670229677057
wealth_index - 0.226149923003438
hh_religion - 0.16171370453338424
women_anemic - 0.06826909691797922
obese_female - 0.06821288996160167
urban - 0.18115518270672787
freq_tv - 0.1282616029624979
age - 0.17617979005723614
years_edu - 0.7987156344690098
hh_members - 0.29268866296869456
no_children_below5 - 0.1687240664435448
total_children - 0.19964121674582727
caste_don't know - 0.004696254182821685
caste_general - 0.05780455248446113
caste_other backward class - 0.04370163180586437
caste_sc/st - 0.056653155262532236
Age Range: 31 - 35 years
Sample size: 12526

state - 0.43514525039973895
wealth_index - 0.22323700971271102
hh_religion - 0.17871876650885135
women_anemic - 0.058287169886163726
obese_female - 0.06006366465696789
urban - 0.15540151324364546
freq_tv - 0.13077046732923375
age - 0.30367518394304205
years_edu - 0.768903033842897
hh_members - 0.2656193403222048
no_children_below5 - 0.14551777960285378
total_children - 0.24463444298126125
caste_don't know - 0.002561669572818768
caste_general - 0.05347679981245632
caste_other backward class - 0.0476053836095569
caste_sc/st - 0.07057875164932456
Age Range: 36 - 40 years
Sample size: 11283

state - 0.44729303526168246
wealth_index - 0.2473032008028458
hh_religion - 0.18018081446841142
women_anemic - 0.05908766300199154
obese_female - 0.05694012976059111
urban - 0.14936164872732102
freq_tv - 0.1319397168779665
age - 0.40761850368263564
years_edu - 0.7490575026842256
hh_members - 0.27600957218837296
no_children_below5 - 0.12536721355847222
total_children - 0.3170210742272066
caste_don't know - 0.002452497877200384
caste_general - 0.060783849356771794
caste_other backward class - 0.04115069326058206
caste_sc/st - 0.05518705274975287
Age Range: 41 - 45 years
Sample size: 10087

state - 0.48383750039526835
wealth_index - 0.22270650587539165
hh_religion - 0.18167773166709394
women_anemic - 0.06908956225665765
obese_female - 0.055723050255578485
urban - 0.17539175521600367
freq_tv - 0.13094860219938798
age - 0.370248899263508
years_edu - 0.7322259787355042
hh_members - 0.2863277700508391
no_children_below5 - 0.10741897380360971
total_children - 0.4393305088255433
caste_don't know - 0.0020607356759359905
caste_general - 0.08105602664832298
caste_other backward class - 0.04648424053773165
caste_sc/st - 0.045575844139291936
Age Range: 46 - 50 years
Sample size: 6013

state - 0.4459999274374072
wealth_index - 0.22822708360994942
hh_religion - 0.15022375214833839
women_anemic - 0.1072409971173746
obese_female - 0.05601420749059834
urban - 0.19524910574308038
freq_tv - 0.1356365629514302
age - 0.3953982249918827
years_edu - 0.765042541967817
hh_members - 0.2801059546518702
no_children_below5 - 0.11746852407920025
total_children - 0.4261548424971268
caste_don't know - 0.001853749782652308
caste_general - 0.07780684787878901
caste_other backward class - 0.06827881454282964
caste_sc/st - 0.04391783667958432

SHAP Interaction Values

SHAP interaction values are a generalization of SHAP values to higher order interactions.

The model returns a matrix for every prediction, where the main effects are on the diagonal and the interaction effects are off-diagonal. The main effects are similar to the SHAP values you would get for a linear model, and the interaction effects capture all the higher-order interactions and divide them up among the pairwise interaction terms.

Note that the sum of the entire interaction matrix is the difference between the model's current output and expected output, and so the interaction effects on the off-diagonal are split in half (since there are two of each). When plotting interaction effects the SHAP package automatically multiplies the off-diagonal values by two to get the full interaction effect.

In [43]:
# Sample from dataset based on sample weights
# (10k rows drawn with survey weights; presumably to better approximate the
# population distribution — confirm this is the intended use of the weights)
dataset_ss = dataset.sample(10000, weights=sample_weights, random_state=random_state)
print(dataset_ss.shape)
# Keep the display copy aligned row-for-row with the subsample
dataset_display_ss = dataset_display.loc[dataset_ss.index]
print(dataset_display_ss.shape)
(10000, 16)
(10000, 16)
In [44]:
# Compute SHAP interaction values (time consuming)
# shap_interaction_values = explainer.shap_interaction_values(dataset_ss)
# NOTE(review): loaded from a pickle cache — ensure the cache was built from
# the same 10k weighted subsample (same random_state), otherwise the rows in
# the plots below will be mislabelled. pickle.load also executes arbitrary
# code; only load trusted caches.
shap_interaction_values = pickle.load(open(f'res/{target_col}-{year}-shapints.obj', 'rb'))
In [45]:
# Summary plot over interaction values (diagonal entries are the main effects)
shap.summary_plot(shap_interaction_values, dataset_display_ss, max_display=15)

Heatmap of SHAP Interaction Values

In [46]:
# Aggregate absolute interaction strength over the sample: entry (i, j) is the
# total |interaction| between features i and j.
total_interaction = np.abs(shap_interaction_values).sum(0)

# Zero the diagonal (main effects) so the heatmap shows interactions only.
# np.fill_diagonal replaces the original explicit loop over tmp[i, i].
np.fill_diagonal(total_interaction, 0)

# Order features by overall interaction strength, strongest first; keep top 50.
inds = np.argsort(-total_interaction.sum(0))[:50]
top_interaction = total_interaction[inds, :][:, inds]

pl.figure(figsize=(12, 12))
pl.imshow(top_interaction)
pl.yticks(range(top_interaction.shape[0]), dataset_display_ss.columns[inds], rotation=50.4, horizontalalignment="right")
pl.xticks(range(top_interaction.shape[0]), dataset_display_ss.columns[inds], rotation=50.4, horizontalalignment="left")
pl.gca().xaxis.tick_top()  # put x labels on top so they don't collide with the image
pl.show()

SHAP Interaction Value Dependence Plots

Running a dependence plot on the SHAP interaction values allows us to separately observe the main effects and the interaction effects.

Below we plot the main effects for age and some of the interaction effects for age. It is informative to compare the main effects plot of age with the earlier SHAP value plot for age. The main effects plot has no vertical dispersion because the interaction effects are all captured in the off-diagonal terms.

Good example of how to infer interesting stuff from interaction values: https://slundberg.github.io/shap/notebooks/NHANES%20I%20Survival%20Model.html

In [47]:
# Main effect of 'age' (diagonal entry of the interaction matrix): no vertical
# dispersion, since interactions live in the off-diagonal terms.
shap.dependence_plot(
    ("age", "age"), 
    shap_interaction_values, dataset_ss, display_features=dataset_display_ss
)

Now we plot the interaction effects involving age (and other features after that). These effects capture all of the vertical dispersion that was present in the original SHAP plot but is missing from the main effects plot above.

Plots for 'age'

In [48]:
# Define pairs of features and interaction indices for dependence plots
# Interaction partners to inspect against 'age'; the first entry ('age' itself)
# gives the main-effect plot.
partners = ['age', 'urban', 'caste_sc/st', 'caste_general',
            'wealth_index', 'years_edu', 'no_children_below5', 'total_children']
pairs = [('age', partner) for partner in partners]

# One dependence plot per (feature, interaction feature) pair
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: age, Interaction Feature: age
Feature: age, Interaction Feature: urban
Feature: age, Interaction Feature: caste_sc/st
Feature: age, Interaction Feature: caste_general
Feature: age, Interaction Feature: wealth_index
Feature: age, Interaction Feature: years_edu
Feature: age, Interaction Feature: no_children_below5
Feature: age, Interaction Feature: total_children

Plots for 'wealth_index'

In [49]:
# Define pairs of features and interaction indices for dependence plots
# Pair 'wealth_index' with itself (main effect) and with each interaction partner
main_feat = 'wealth_index'
partners = ['wealth_index', 'age', 'urban', 'caste_sc/st', 'caste_general',
            'years_edu', 'no_children_below5', 'total_children']
pairs = [(main_feat, partner) for partner in partners]

# One dependence plot per (feature, interaction feature) pair
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: wealth_index, Interaction Feature: wealth_index
Feature: wealth_index, Interaction Feature: age
Feature: wealth_index, Interaction Feature: urban
Feature: wealth_index, Interaction Feature: caste_sc/st
Feature: wealth_index, Interaction Feature: caste_general
Feature: wealth_index, Interaction Feature: years_edu
Feature: wealth_index, Interaction Feature: no_children_below5
Feature: wealth_index, Interaction Feature: total_children

Plots for 'years_edu'

In [50]:
# Define pairs of features and interaction indices for dependence plots
# Pair 'years_edu' with itself (main effect) and with each interaction partner
main_feat = 'years_edu'
partners = ['years_edu', 'age', 'urban', 'caste_sc/st', 'caste_general',
            'wealth_index', 'no_children_below5', 'total_children']
pairs = [(main_feat, partner) for partner in partners]

# One dependence plot per (feature, interaction feature) pair
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: years_edu, Interaction Feature: years_edu
Feature: years_edu, Interaction Feature: age
Feature: years_edu, Interaction Feature: urban
Feature: years_edu, Interaction Feature: caste_sc/st
Feature: years_edu, Interaction Feature: caste_general
Feature: years_edu, Interaction Feature: wealth_index
Feature: years_edu, Interaction Feature: no_children_below5
Feature: years_edu, Interaction Feature: total_children

Plots for 'caste_sc/st'

In [51]:
# Define pairs of features and interaction indices for dependence plots
# Pair 'caste_sc/st' with itself (main effect) and with each interaction partner
main_feat = 'caste_sc/st'
partners = ['caste_sc/st', 'age', 'urban', 'years_edu',
            'wealth_index', 'no_children_below5', 'total_children']
pairs = [(main_feat, partner) for partner in partners]

# One dependence plot per (feature, interaction feature) pair
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: caste_sc/st, Interaction Feature: caste_sc/st
Feature: caste_sc/st, Interaction Feature: age
Feature: caste_sc/st, Interaction Feature: urban
Feature: caste_sc/st, Interaction Feature: years_edu
Feature: caste_sc/st, Interaction Feature: wealth_index
Feature: caste_sc/st, Interaction Feature: no_children_below5
Feature: caste_sc/st, Interaction Feature: total_children

Plots for 'caste_general'

In [52]:
# Define pairs of features and interaction indices for dependence plots
# Pair 'caste_general' with itself (main effect) and with each interaction partner
main_feat = 'caste_general'
partners = ['caste_general', 'age', 'urban', 'years_edu',
            'wealth_index', 'no_children_below5', 'total_children']
pairs = [(main_feat, partner) for partner in partners]

# One dependence plot per (feature, interaction feature) pair
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: caste_general, Interaction Feature: caste_general
Feature: caste_general, Interaction Feature: age
Feature: caste_general, Interaction Feature: urban
Feature: caste_general, Interaction Feature: years_edu
Feature: caste_general, Interaction Feature: wealth_index
Feature: caste_general, Interaction Feature: no_children_below5
Feature: caste_general, Interaction Feature: total_children